{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "low_level_invoke_function.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "FH3IRpYTta2v"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FH3IRpYTta2v"
      },
      "source": [
        "##### Copyright 2019 The IREE Authors"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mWGa71_Ct2ug",
        "cellView": "form"
      },
      "source": [
        "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
        "# See https://llvm.org/LICENSE.txt for license information.\n",
        "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZEmrd07EvthK"
      },
      "source": [
        "# Low Level Invoke Function"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uMVh8_lZDRa7"
      },
      "source": [
        "This notebook demonstrates some concepts of the low-level IREE Python bindings: it compiles a small MLIR module and then invokes one of its functions through the runtime."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Go2Nw7BgIHYU",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "0339165a-a35f-4b46-9cf8-f22adc69a7fe"
      },
      "source": [
        "# Use the %pip magic so the packages are installed into the environment of\n",
        "# the running kernel rather than whichever `python` a subshell finds on PATH.\n",
        "%pip install --pre iree-base-compiler iree-base-runtime -f https://iree.dev/pip-release-links.html"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Looking in links: https://iree.dev/pip-release-links.html\n",
            "Collecting iree-compiler\n",
            "  Downloading https://github.com/iree-org/iree/releases/download/candidate-20220929.281/iree_compiler-20220929.281-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (49.7 MB)\n",
            "\u001b[K     |████████████████████████████████| 49.7 MB 99 kB/s \n",
            "\u001b[?25hCollecting iree-runtime\n",
            "  Downloading https://github.com/iree-org/iree/releases/download/candidate-20220929.281/iree_runtime-20220929.281-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.3 MB)\n",
            "\u001b[K     |████████████████████████████████| 2.3 MB 40.7 MB/s \n",
            "\u001b[?25hRequirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from iree-compiler) (6.0)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from iree-compiler) (1.21.6)\n",
            "Installing collected packages: iree-runtime, iree-compiler\n",
            "Successfully installed iree-compiler-20220929.281 iree-runtime-20220929.281\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1F144M4wAFPz"
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "from iree import runtime as ireert\n",
        "from iree.compiler import compile_str"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2Rq-JdzMAFPU"
      },
      "source": [
        "# MLIR source for a module named `arithmetic` exposing one function,\n",
        "# `simple_mul`, which multiplies two 4-element f32 tensors.\n",
        "SIMPLE_MUL_ASM = \"\"\"\n",
        "  module @arithmetic {\n",
        "    func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {\n",
        "      %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>\n",
        "      return %0 : tensor<4xf32>\n",
        "    } \n",
        "  }\n",
        "\"\"\"\n",
        "\n",
        "# Compile the source for the vmvx (reference) target backend.\n",
        "compiled_flatbuffer = compile_str(\n",
        "    SIMPLE_MUL_ASM,\n",
        "    target_backends=[\"vmvx\"],\n",
        ")"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TNQiNeOU_cpK",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4f112fcf-34fe-4d36-c5da-825dc721bdbf"
      },
      "source": [
        "# Build a runtime context on the \"local-task\" CPU driver (which can load the\n",
        "# vmvx executable) and register the compiled module with it.\n",
        "ctx = ireert.SystemContext(config=ireert.Config(\"local-task\"))\n",
        "vm_module = ireert.VmModule.from_flatbuffer(ctx.instance, compiled_flatbuffer)\n",
        "ctx.add_vm_module(vm_module)\n",
        "\n",
        "# Look the function up by module and name, call it with two float32 vectors,\n",
        "# and print what `to_host()` returns for the result.\n",
        "print(\"INVOKE simple_mul\")\n",
        "lhs = np.array([1., 2., 3., 4.], dtype=np.float32)\n",
        "rhs = np.array([4., 5., 6., 7.], dtype=np.float32)\n",
        "simple_mul = ctx.modules.arithmetic[\"simple_mul\"]\n",
        "results = simple_mul(lhs, rhs).to_host()\n",
        "print(\"Results:\", results)"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "INVOKE simple_mul\n",
            "Results: [ 4. 10. 18. 28.]\n"
          ]
        }
      ]
    }
  ]
}
